{
  "nbformat": 4,
  "nbformat_minor": 0,
  "metadata": {
    "colab": {
      "provenance": []
    },
    "kernelspec": {
      "name": "python3",
      "display_name": "Python 3"
    },
    "language_info": {
      "name": "python"
    }
  },
  "cells": [
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "03a_w5DhtbNn"
      },
      "outputs": [],
      "source": [
        "from PIL import Image\n",
        "import requests\n",
        "url = \"https://farm4.staticflickr.com/3487/3925656789_1b64654c91_z.jpg\"\n",
        "image = Image.open(requests.get(url, stream=True).raw)"
      ]
    },
    {
      "cell_type": "code",
      "source": [
        "from transformers import CLIPSegProcessor, CLIPSegForImageSegmentation\n",
        "\n",
        "processor = CLIPSegProcessor.from_pretrained(\"CIDAS/clipseg-rd64-refined\")\n",
        "model = CLIPSegForImageSegmentation.from_pretrained(\"CIDAS/clipseg-rd64-refined\")\n"
      ],
      "metadata": {
        "id": "h8xqlzs-tezH"
      },
      "execution_count": null,
      "outputs": []
    },
    {
      "cell_type": "code",
      "source": [
        "prompts = [\"hat\", \"ball\", \"player\", \"red shirt\"]\n",
        "inputs = processor(\n",
        "    text=prompts,\n",
        "    images=[image] * len(prompts),\n",
        "    padding=\"max_length\",\n",
        "    return_tensors=\"pt\",\n",
        ")\n"
      ],
      "metadata": {
        "id": "DPnGAyCcti5N"
      },
      "execution_count": null,
      "outputs": []
    },
    {
      "cell_type": "code",
      "source": [
        "outputs = model(**inputs)\n",
        "preds = outputs.logits.unsqueeze(1)\n"
      ],
      "metadata": {
        "id": "9EbFeclbtm_o"
      },
      "execution_count": null,
      "outputs": []
    },
    {
      "cell_type": "code",
      "source": [
        "import matplotlib.pyplot as plt\n",
        "\n",
        "_, ax = plt.subplots(1, 5, figsize=(15, 4))\n",
        "[a.axis(\"off\") for a in ax.flatten()]\n",
        "ax[0].imshow(image)\n",
        "[ax[i + 1].imshow(torch.sigmoid(preds[i][0])) for i in range(4)]\n",
        "[ax[i + 1].text(0, -15, prompts[i]) for i in range(4)]\n"
      ],
      "metadata": {
        "id": "fdXKL422tqVd"
      },
      "execution_count": null,
      "outputs": []
    },
    {
      "cell_type": "code",
      "source": [
        "url = \"https://upload.wikimedia.org/wikipedia/en/1/1e/Baseball_%28crop%29.jpg\"\n",
        "prompt = Image.open(requests.get(url, stream=True).raw)\n"
      ],
      "metadata": {
        "id": "GnTxOGrXtwMv"
      },
      "execution_count": null,
      "outputs": []
    },
    {
      "cell_type": "code",
      "source": [
        "encoded_image = processor(images=[image], return_tensors=\"pt\")\n",
        "encoded_prompt = processor(images=[prompt], return_tensors=\"pt\")\n",
        "outputs = model(**encoded_image, conditional_pixel_values=encoded_prompt.pixel_values)\n"
      ],
      "metadata": {
        "id": "1isv_8MWt12x"
      },
      "execution_count": null,
      "outputs": []
    },
    {
      "cell_type": "code",
      "source": [
        "preds = outputs.logits.unsqueeze(1)\n",
        "preds = torch.transpose(preds, 0, 1)\n",
        "_, ax = plt.subplots(1, 2, figsize=(6, 4))\n",
        "[a.axis(\"off\") for a in ax.flatten()]\n",
        "ax[0].imshow(image)\n",
        "ax[1].imshow(torch.sigmoid(preds[0]))\n"
      ],
      "metadata": {
        "id": "0s8kqCjtt6Fx"
      },
      "execution_count": null,
      "outputs": []
    }
  ]
}